from read import read_trj, read_intc, read_energy
import subprocess
import re
import sys, os

import numpy as np
from subprocess import Popen, PIPE
from math import pi
# Module-level shorthand for numpy's vector/matrix norm; find_basin uses it
# to compute box-scaled distances to the basin centers.
norm = np.linalg.norm



def run_oneway_commitor(iiter, name, xyzflat, vxyzflat, temp, steps, timestep, damp):
    """Run one LAMMPS trajectory from a given state and classify where it ends up.

    Writes ``{name}.dat`` (via ``write_data``) and ``{name}.in`` /
    ``{name}-plumed.dat`` (via ``write_inscript``), launches LAMMPS through
    ``mpirun -np 1``, then reads the results back with ``read_trj``,
    ``read_intc`` and ``read_energy`` and classifies the trajectory with
    ``find_basin``.

    The LAMMPS executable is taken from the ``LMP_EXEC`` environment
    variable (default ``lmp_gcc``); the value is split on whitespace, so it
    may carry extra command-line flags.

    Args:
        iiter: Iteration index; forwarded to ``write_inscript`` where it is
            appended to the Langevin random seed.
        name: Base name used for every file this run reads or writes.
        xyzflat: Flat coordinate array passed to ``write_data``.
        vxyzflat: Flat velocity array passed to ``write_data``.
        temp, steps, timestep, damp: Forwarded unchanged to
            ``write_inscript``.

    Returns:
        Tuple ``(newxyz, newvxyz, alle, intcoord, basin, ends)`` where the
        first two come from ``read_trj``, ``alle`` from ``read_energy``,
        ``intcoord`` from ``read_intc`` and ``basin, ends`` from
        ``find_basin``.

    Raises:
        NameError: If LAMMPS cannot be started or exits with an error, or if
            the energy, internal-coordinate and trajectory arrays do not
            have the same number of rows.
    """
    filename = f"{name}"
    write_data(filename, xyzflat, vxyzflat)
    write_inscript(iiter, filename, indat=filename, outdat=f"{filename}-out", temp=temp, steps=steps, timestep=timestep, damp=damp)
    lmp = os.environ.get('LMP_EXEC', 'lmp_gcc')
    # Build argv as a list so that a run name containing whitespace is
    # passed to LAMMPS as a single argument.
    command = ["mpirun", "-np", "1", *lmp.split(), "-in", f"{filename}.in"]

    try:
        subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        # Keep the end of the program output: it normally holds the reason
        # LAMMPS aborted.
        tail = err.output.decode(errors="replace")[-2000:] if err.output else ""
        raise NameError(
            f"fail to compute commit: exit status {err.returncode}\n{tail}"
        ) from err
    except (subprocess.SubprocessError, OSError, ValueError) as err:
        # E.g. mpirun or the LAMMPS binary is missing or not executable.
        raise NameError(f"fail to compute commit: {err}") from err

    newxyz, newvxyz = read_trj(name)
    intcoord = read_intc(name)
    alle = read_energy(name)
    # NOTE: find_basin wraps columns 1 and 2 of intcoord in place, so the
    # array returned below carries the wrapped values.
    basin, ends = find_basin(intcoord)

    # All three outputs are written every step, so they must line up.
    if (alle.shape[0]!=intcoord.shape[0]):
        raise NameError(f"alle.shape[0]{alle.shape[0]} != intcoord.shape[0]{intcoord.shape[0]}")
    if (alle.shape[0]!=newxyz.shape[0]):
        raise NameError(f"alle.shape[0]{alle.shape[0]} != newxyz.shape[0]{newxyz.shape[0]}")

    # Remove the PLUMED by-products once they have been read.
    os.remove(f"{name}-plumed.out")
    os.remove(f"{name}.intc")

    return newxyz, newvxyz, alle, intcoord, basin, ends

def find_firste():
    """Read the values printed right after a ``Step`` line in ``log.lammps``.

    Scans ``log.lammps`` in the current working directory.  Whenever a line
    follows a line containing the text ``Step``, whitespace-separated fields
    1 through 4 (0-based) of that line are converted to floats.  If this
    happens more than once, the values from the last occurrence are kept.

    Returns:
        A numpy float array with the parsed fields, or an empty list when no
        line follows a ``Step`` line.
    """
    marker = "Step"
    values = []
    follows_header = False
    with open("log.lammps") as log:
        for row in log:
            if follows_header:
                values = np.array([float(field) for field in row.split()[1:5]])
            # Only the line directly after a header line gets parsed.
            follows_header = marker in row
    return values


def find_basin(intc):
    """Find the first frame (after frame 0) that lies inside a basin box.

    Columns 1 and 2 of ``intc`` are used as a 2-D coordinate for every
    frame (presumably the t1/t2 torsions printed by PLUMED — confirm against
    ``read_intc``).  Three basins (index 0 = C5, 1 = C7eq, 2 = Cax) are each
    described by two opposite box corners; a frame counts as inside a basin
    when its offset from the box center, divided component-wise by the box
    half-widths, has Euclidean norm below 1.

    Before testing, column 2 values below -2.5 are shifted by +2*pi and
    column 1 values above 3 are shifted by -2*pi.

    NOTE: the shift is applied through a view, so the caller's ``intc`` is
    modified in place.  This is kept on purpose: callers receive the wrapped
    values.

    Args:
        intc: 2-D numpy float array with at least three columns, one row
            per frame.

    Returns:
        ``(basin, ends)``: the basin index and the index of the first frame
        ``>= 1`` found inside a basin.  When several basins match the same
        frame the lowest basin index wins.  ``(-1, -1)`` when no frame
        commits or the trajectory is empty.  A diagnostic line is printed
        in every case.
    """
    # Opposite corners [[lo0, lo1], [hi0, hi1]] of each basin box, tuned for
    # damp 50.  Row order defines the basin index: C5, C7eq, Cax.
    # Alternative boxes recorded for damp 500:
    #   C5   [[-2.9, 2.44], [-2.05, 3.15]]
    #   C7eq [[-1.61, 0.48], [-1.08, 1.4]]
    #   Cax  [[0.8, -1.2], [1.25, -0.26]]
    # NOTE(review): the C7eq corners below give a negative half-width in the
    # first coordinate (-1.6 -> -1.7).  Only its magnitude (0.05) affects the
    # norm, but such a narrow box may be a typo — confirm.
    corners = np.array([
        [[-2.88, 2.63], [-2.21, 2.89]],
        [[-1.6, 0.48], [-1.7, 1.3]],
        [[0.8, -1.1], [1.21, -0.32]],
    ])
    center = (corners[:, 0, :] + corners[:, 1, :]) / 2.
    radius = (-corners[:, 0, :] + corners[:, 1, :]) / 2.

    if intc.shape[0] == 0:
        # Nothing to classify; report "no basin" instead of failing on the
        # last-frame lookup further down.
        print("cannot find basin: trajectory has no frames")
        return -1, -1

    # Basic slicing yields views, so the wrapping below writes into intc.
    allconfig = intc[:, 1:3]
    second = allconfig[:, 1]
    second[second < -2.5] += 2*pi
    first = allconfig[:, 0]
    first[first > 3] -= 2*pi

    # dist[k, i]: box-scaled distance of frame i from the center of basin k.
    scaled = (allconfig[np.newaxis, :, :] - center[:, np.newaxis, :]) / radius[:, np.newaxis, :]
    dist = norm(scaled, axis=2)

    # Frame 0 is the starting configuration and is never counted as a commit.
    inside = dist[:, 1:] < 1
    committed = np.flatnonzero(inside.any(axis=0))
    if committed.size:
        ends = int(committed[0]) + 1
        # argmax on booleans returns the first True, i.e. the lowest basin index.
        basin = int(np.argmax(inside[:, ends - 1]))
        print("find basin", basin, ends, allconfig[ends], allconfig[-1])
        return basin, ends

    last = dist[:, -1]
    t = intc[-1, 1:3]
    print(f"cannot find basin {t} {allconfig[-1]} {last[0]} {last[1]} {last[2]}")
    return -1, -1


def write_data(name, xyzflat, vxyzflat0):
    """Write the LAMMPS data file ``{name}.dat`` for one configuration.

    The file consists of the module-level ``before_xyz`` header, an
    ``Atoms`` section, a ``Velocities`` section and the module-level
    ``after_xyz`` trailer.

    Args:
        name: Base name of the output file (``.dat`` is appended).
        xyzflat: Flat array of coordinates, three consecutive entries per
            atom; its ``shape[0]`` determines the number of atoms written.
        vxyzflat0: Flat array of velocities laid out like ``xyzflat``.

    Raises:
        IndexError: If more atoms are supplied than the built-in type and
            charge tables hold (22), or if ``vxyzflat0`` is too short.
    """
    # Per-atom LAMMPS type and partial charge, indexed by atom position.
    atom_types = (1, 2, 1, 1, 3, 4, 5, 6, 2, 7, 2, 1, 1, 1, 3, 4, 5, 6,
                  2, 7, 7, 7)
    charges = (0.1123, -0.3662, 0.1123, 0.1123, 0.5972, -0.5679, -0.4157,
               0.2719, 0.0337, 0.0823, -0.1825, 0.0603, 0.0603, 0.0603,
               0.5973, -0.5679, -0.4157, 0.2719, -0.149, 0.0976, 0.0976,
               0.0976)
    n_atoms = xyzflat.shape[0] // 3

    with open(f"{name}.dat", "w+") as out:

        def emit(*fields):
            # One space-separated, newline-terminated record.
            print(*fields, file=out)

        emit(before_xyz)

        emit("Atoms\n")
        for k in range(n_atoms):
            base = 3 * k
            # Columns: atom id, molecule id "1", type, charge, x, y, z.
            emit(k + 1, "1", atom_types[k], charges[k],
                 xyzflat[base], xyzflat[base + 1], xyzflat[base + 2])
        emit("")

        emit("Velocities\n")
        for k in range(n_atoms):
            base = 3 * k
            emit(k + 1, vxyzflat0[base], vxyzflat0[base + 1], vxyzflat0[base + 2])
        emit("")

        emit(after_xyz)

def write_inscript(iiter=0, name="input", indat="in", outdat="out", temp=300,
        steps=2000, timestep=0.5, damp=100):
    """Write the LAMMPS input script and the matching PLUMED input file.

    Creates ``{name}.in`` (LAMMPS) and ``{name}-plumed.dat`` (PLUMED).  The
    LAMMPS script reads ``{indat}.dat``, applies ``fix nve`` plus
    ``fix langevin``, and writes ``{name}.lammpstrj`` and ``{name}.energy``
    every step.  The PLUMED file prints a fixed set of torsions, distances
    and angles to ``{name}.intc`` every step.

    Args:
        iiter: Appended textually to ``3214`` to form the Langevin seed.
        name: Base name for the generated and referenced files.
        indat: Base name of the data file the script reads.
        outdat: Accepted for interface compatibility; not used by the
            templates.
        temp: Start and end temperature of the Langevin thermostat.
        steps: Number of steps passed to ``run``.
        timestep: Value passed to ``timestep``.
        damp: Damping parameter of the Langevin thermostat.
    """
    lammps_input = f"""units		real

neigh_modify    once yes one 22 page 2200

atom_style	full
bond_style      harmonic
angle_style     harmonic
dihedral_style  harmonic
pair_style      lj/cut/coul/cut 10.0
pair_modify     mix arithmetic
special_bonds   amber
kspace_style    none
read_data       {indat}.dat

thermo          1
thermo_style custom step temp pe ke etotal press

dump            1 all custom 1 {name}.lammpstrj id type x y z vx vy vz
fix             1 all nve
fix             2 all langevin {temp} {temp} {damp} 3214{iiter}
variable e     equal pe
variable k     equal ke
variable t     equal etotal
variable T     equal temp
fix extra all print 1 "$T $e $k $t" file {name}.energy
fix        plumed all plumed plumedfile {name}-plumed.dat outfile {name}-plumed.out
timestep        {timestep}
run             {steps}   #1 ns
"""

    plumed_input = f"""t1: TORSION ATOMS=2,5,7,9 NOPBC
t2: TORSION ATOMS=5,7,9,15 NOPBC
t3: TORSION ATOMS=7,9,15,17 NOPBC
t4: TORSION ATOMS=9,15,17,19 NOPBC
t5: TORSION ATOMS=17,15,9,11 NOPBC
t6: TORSION ATOMS=11,9,7,5 NOPBC
d1: DISTANCE ATOMS=19,17 NOPBC
d2: DISTANCE ATOMS=17,15 NOPBC
d3: DISTANCE ATOMS=15,9 NOPBC
d4: DISTANCE ATOMS=9,11 NOPBC
d5: DISTANCE ATOMS=9,7 NOPBC
d6: DISTANCE ATOMS=7,5 NOPBC
d7: DISTANCE ATOMS=5,2 NOPBC
d8: DISTANCE ATOMS=14,16 NOPBC
d9: DISTANCE ATOMS=5,6 NOPBC
a1: ANGLE ATOMS=9,17,15 NOPBC
a2: ANGLE ATOMS=17,15,9 NOPBC
a3: ANGLE ATOMS=15,9,7 NOPBC
a4: ANGLE ATOMS=9,7,5 NOPBC
a5: ANGLE ATOMS=7,5,2 NOPBC
a6: ANGLE ATOMS=15,9,11 NOPBC
a7: ANGLE ATOMS=11,9,7 NOPBC
PRINT ARG=* STRIDE=1 FILE={name}.intc
FLUSH STRIDE=1
"""

    # LAMMPS script first, then the PLUMED file; each gets one extra
    # trailing newline after the template text.
    outputs = (
        (f"{name}.in", lammps_input),
        (f"{name}-plumed.dat", plumed_input),
    )
    for path, text in outputs:
        with open(path, "w+") as handle:
            handle.write(text + "\n")

# Leading part of the LAMMPS data file emitted by write_data before the
# "Atoms" section: title line, atom/bond/angle/dihedral counts, type counts,
# the simulation box bounds and the per-type masses.  The counts here must
# stay in agreement with the topology listed in after_xyz and with the
# 22-entry type/charge tables in write_data.
before_xyz = """ LAMMPS data file for ACE

22 atoms
21 bonds
36 angles
66 dihedrals

7 atom types
8 bond types
16 angle types
19 dihedral types

-100 100  xlo xhi
-100 100  ylo yhi
-100 100  zlo zhi

Masses

1 1.008
2 12.01
3 12.01
4 16.0
5 14.01
6 1.008
7 1.008

"""
# Trailing part of the LAMMPS data file emitted by write_data after the
# "Velocities" section: the fixed topology (Bonds, Angles, Dihedrals) and the
# force-field tables (Pair, Bond, Angle and Dihedral Coeffs).  Each topology
# row is "index type atom-ids..."; the section sizes (21 / 36 / 66) and type
# counts (8 / 16 / 19) match the counts declared in before_xyz.
after_xyz = """
Bonds

 1 3 2   3
 2 3 2   4
 3 3 1   2
 4 3 11 12
 5 3 11 13
 6 3 11 14
 7 5 9  10
 8 7 7   8
 9 5 19 20
10 5 19 21
11 5 19 22
12 7 17 18
13 1 5   6
14 2 5   7
15 4 2   5
16 1 15 16
17 2 15 17
18 6 9  11
19 4 9  15
20 8 7   9
21 8 17 19

Angles

1 2 5 7 8
2 4 4 2 5
3 5 3 2 4
4 4 3 2 5
5 5 1 2 3
6 5 1 2 4
7 4 1 2 5
8 2 15 17 18
9 5 13 11 14
10 5 12 11 13
11 5 12 11 14
12 9 10 9 11
13 10 10 9 15
14 11 9 11 12
15 11 9 11 13
16 11 9 11 14
17 12 8 7 9
18 13 7 9 10
19 16 21 19 22
20 16 20 19 21
21 16 20 19 22
22 12 18 17 19
23 13 17 19 20
24 13 17 19 21
25 13 17 19 22
26 1 6 5 7
27 3 5 7 9
28 6 2 5 6
29 7 2 5 7
30 1 16 15 17
31 3 15 17 19
32 8 11 9 15
33 6 9 15 16
34 7 9 15 17
35 14 7 9 11
36 15 7 9 15

Dihedrals

1 1 6 5 7 8
2 2 6 5 7 8
3 3 5 7 9 10
4 10 4 2 5 6
5 3 4 2 5 6
6 11 4 2 5 6
7 3 4 2 5 7
8 10 3 2 5 6
9 3 3 2 5 6
10 11 3 2 5 6
11 3 3 2 5 7
12 2 2 5 7 8
13 10 1 2 5 6
14 3 1 2 5 6
15 11 1 2 5 6
16 3 1 2 5 7
17 1 16 15 17 18
18 2 16 15 17 18
19 3 15 17 19 20
20 3 15 17 19 21
21 3 15 17 19 22
22 12 14 11 9 15
23 12 13 11 9 15
24 12 12 11 9 15
25 12 10 9 11 12
26 12 10 9 11 13
27 12 10 9 11 14
28 10 10 9 15 16
29 11 10 9 15 16
30 3 10 9 15 17
31 2 9 15 17 18
32 3 8 7 9 10
33 3 8 7 9 11
34 3 8 7 9 15
35 12 7 9 11 12
36 12 7 9 11 13
37 12 7 9 11 14
38 3 18 17 19 20
39 3 18 17 19 21
40 3 18 17 19 22
41 19 5 9 7 8
42 19 15 19 17 18
43 2 6 5 7 9
44 1 5 7 9 11
45 4 5 7 9 11
46 5 5 7 9 11
47 6 5 7 9 11
48 7 5 7 9 15
49 8 5 7 9 15
50 9 5 7 9 15
51 6 5 7 9 15
52 2 2 5 7 9
53 2 16 15 17 19
54 3 11 9 15 16
55 13 11 9 15 17
56 14 11 9 15 17
57 5 11 9 15 17
58 6 11 9 15 17
59 2 9 15 17 19
60 3 7 9 15 16
61 15 7 9 15 17
62 16 7 9 15 17
63 17 7 9 15 17
64 6 7 9 15 17
65 18 2 7 5 6
66 18 9 17 15 16

Pair Coeffs

1 0.01570000002623629  2.6495327872602221
2 0.10939999991572773  3.3996695084507409
3 0.086000000128358844 3.3996695079448309
4 0.20999999984182244  2.9599219016446874
5 0.16999999991766696  3.2499985240310356
6 0.015700000004219245 1.0690784617205229
7 0.015700000098461422 2.4713530426421655

Bond Coeffs

1 570.0 1.229
2 490.0 1.335
3 340.0 1.090
4 317.0 1.522
5 340.0 1.090
6 310.0 1.526
7 434.0 1.010
8 337.0 1.449

Angle Coeffs

1 80.0 122.90005267195104
2 50.0 120.00005142908158
3 50.0 121.90005224337536
4 50.0 109.50004692903693
5 35.0 109.50004692903693
6 80.0 120.40005160051184
7 70.0 116.60004997192425
8 63.0 111.10004761475803
9 50.0 109.50004692903693
10 50.0 109.50004692903693
11 50.0 109.50004692903693
12 50.0 118.04005047448166
13 50.0 109.50004692903693
14 80.0 109.70004701475206
15 63.0 110.10004718618234
16 35.0 109.50004692903693

Dihedral Coeffs

 1 2.0          1 1
 2 2.5         -1 2
 3 0.0          1 2
 4 2.0          1 2
 5 0.4          1 3
 6 0.0          1 4
 7 0.0          1 1
 8 0.272        1 2
 9 0.43         1 3
10 0.8          1 1
11 0.08        -1 3
12 0.155555556  1 3
13 0.20         1 1
14 0.20         1 2
15 0.45        -1 1
16 1.58        -1 2
17 0.55        -1 3
18 10.5        -1 2
19 1.10        -1 2
"""
